/**
 * @file KdTree.cpp
 * @author enemy1205 (enemy1205@qq.com)
 * @brief TreeNode以及Kd树的构建
 * @date 2021-09-24
 */
#include "KNN.h"

TreeNode::TreeNode(const vector<double> &coord, const uint &axis, const uint &ind) : axis(axis), index(ind), lchild(nullptr), rchild(nullptr)
{
    //多维坐标vector给成员数组赋值
    assert(!coord.empty());
    coordinate = coord;
}

/**
 * @brief 为每一个数据做标签
 * @param Data 输入数据
 */
KdTree::KdTree(const vector<vector<double>> &Data)
{
    assert(!Data.empty());
    assert(!Data[0].empty());
    vector<pair<vector<double>, uint>> DataWithInd;
    for (uint i = 0; i < Data.size(); i++)
    {
        DataWithInd.emplace_back(make_pair(Data[i], i));
    }
    dim = Data[0].size();
    //根据栈先进后出的原则，返回的是根节点
    tree = buildKdTree(DataWithInd.begin(), DataWithInd.end());
}

/**
 * @brief 构建Kd树
 * @param l 左迭代器
 * @param r 右迭代器
 * @return Kd树节点
 */
TreeNode *KdTree::buildKdTree(const vector<pair<vector<double>, uint>>::iterator &l,
                              const vector<pair<vector<double>, uint>>::iterator &r)
{ //首末相减即vector长度,数据个数
    uint Size = r - l;
    //寻找最大方差的维度
    if (Size > 1)
    {
        uint axis = 0;
        double max_var = 0;
        for (uint i = 0; i < dim; i++)
        {
            double temp = Math::getVariance(l, r, i);
            if (temp > max_var)
            {
                max_var = temp;
                axis = i;
            }
        }
        //在方差最大的那一维度升序排序
        sort(l, r, [&axis](const pair<vector<double>, uint> &x, const pair<vector<double>, uint> &y)
             { return (x.first)[axis] < (y.first)[axis]; });
        //取中间元素作为父类节点
        auto element = (l + Size / 2);
        auto *t = new TreeNode((*element).first, axis, (*element).second);
        //递归构造子类节点
        t->setLchild(buildKdTree(l, element));
        t->setRchild(buildKdTree(element + 1, r));
        return t;
    } //递归终止条件
    else if (Size == 1)
    {
        //最终节点维度设置为dim，方便后面查询判断
        auto *t = new TreeNode((*l).first, ((*l).first).size(), (*l).second);
        return t;
    }
    else
    {
        return nullptr;
    }
}

/**
 * @brief 删除Kd树
 * @param t
 */
void KdTree::deleteKdTree(TreeNode *t)
{
    assert(t);
    auto lchild = t->getLchild();
    auto rchild = t->getRchild();
    //左中右递归删除
    if (lchild)
        deleteKdTree(lchild);
    delete t;
    if (rchild)
        deleteKdTree(rchild);
}

/**
 * @brief 在Kd树中寻找值所属的子空间
 * @param path 所经过的节点
 * @param t Kd根节点
 * @param P 待寻找的数据
 */
void KdTree::findPath(vector<TreeNode *> &path, TreeNode *t, const vector<double> &P) const
{
    while (t)
    { //推入路径向量中
        path.emplace_back(t);
        //得到该节点划分维度
        uint axis = t->getAxis();
        const vector<double> coord = t->getCoordinate();
        //最终节点axis为dim
        if (axis < dim)
            t = (P[axis] < coord[axis] ? t->getLchild() : t->getRchild());
        else
            t = nullptr;
    }
}

/**
 * @brief 二分法插入新的邻近点<index,distance>
 * @param sorted_list k各邻近点有序序列(距离由小到大)
 * @param new_element 新邻近点
 */
void KdTree::orderInsert(vector<pair<uint, double>> &sorted_list, const pair<uint, double> &new_element)
{
    auto l = sorted_list.begin(), r = sorted_list.end();
    uint num = r - l;
    while (num)
    {
        if (new_element.second < (*(l + num / 2)).second)
        {
            r = l + num / 2;
        }
        else
        {
            l = l + num / 2 + 1;
        }
        num = r - l;
    }
    sorted_list.insert(r, new_element);
}

/**
 * @brief 搜寻给定数据邻近k个点
 * @note 末端节点直接比较，然后pop,非末端节点先计算该节点与P的距离,顺序插入后得末端最大dis，再比较其在axis维度与P的差值(即另一子空间与P的最小距离)，若
 *        小于dis,则另一子空间可能存在更小距离，若大于dis,则直接过滤掉另一子空间,pop当前节点
 * @param P 待搜寻数据
 * @param k 搜寻个数
 * @return pare<data_index,distance>向量
 */
vector<pair<uint, double>> KdTree::findKneighbor(const vector<double> &P, const uint &k) const
{
    assert(P.size() == dim);
    assert(k > 0);
    vector<TreeNode *> path; //路径所经过节点
    vector<pair<uint, double>> result;
    findPath(path, tree, P);
    while (!path.empty())
    {
        TreeNode *temp = path[path.size() - 1];
        uint axis = temp->getAxis();
        vector<double> coord = temp->getCoordinate();
        //末端子节点
        if (axis >= dim)
        {
            double distance_sq = Math::getDistance(coord, P);
            // result还未满并且距离小于结果中最大
            if (result.size() < k || distance_sq < result[result.size() - 1].second)
            {
                auto neighbor = make_pair(temp->getIndex(), distance_sq); //<index,distance>
                //插入该邻近点
                orderInsert(result, neighbor);
            }
            if (result.size() > k)
            {
                result.pop_back();
            }
            //遍历完path中末节点即直接pop
            path.pop_back();
        }
        //非末端节点
        else
        { // result未满时直接顺序插入
            if (result.size() < k)
            {
                double distance_sq = Math::getDistance(coord, P);
                orderInsert(result, make_pair(temp->getIndex(), distance_sq));
                path.pop_back();
                //以当前节点为根节点添加另一子空间进path
                findPath(path, P[axis] < coord[axis] ? temp->getRchild() : temp->getLchild(), P);
            }
            // result 满后比较再插入
            //在axis维度的差值小于result中最大的距离,继续寻找
            else if (result[result.size() - 1].second >= (P[axis] - coord[axis]) * (P[axis] - coord[axis]))
            {
                double distance_sq = Math::getDistance(coord, P);
                if (distance_sq < result[result.size() - 1].second)
                {
                    orderInsert(result, make_pair(temp->getIndex(), distance_sq));
                    result.pop_back();
                }
                path.pop_back();
                findPath(path, P[axis] < coord[axis] ? temp->getRchild() : temp->getLchild(), P);
            }
            //在axis维度的差值已经大于result中最大的距离,直接弹出
            else
            {
                path.pop_back();
            }
        }
    }
    return result;
}